{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torchvision\n",
    "from torch.utils import data\n",
    "from torchvision import transforms\n",
    "\n",
    "# Load Fashion-MNIST. download=True only fetches the files when they are not\n",
    "# already present under DATA_ROOT, so the notebook also runs on a fresh machine.\n",
    "DATA_ROOT = \"../../data\"\n",
    "mnist_train = torchvision.datasets.FashionMNIST(\n",
    "    root=DATA_ROOT, train=True, transform=transforms.ToTensor(), download=True)\n",
    "mnist_test = torchvision.datasets.FashionMNIST(\n",
    "    root=DATA_ROOT, train=False, transform=transforms.ToTensor(), download=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of samples in the training and test splits\n",
    "tuple(len(split) for split in (mnist_train, mnist_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each element of the dataset is a (feature, label) tuple: the image tensor and its integer class id\n",
    "mnist_train[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_fashion_mnist_labels(labels):\n",
    "    \"\"\"Convert numeric Fashion-MNIST class ids into their text labels.\n",
    "\n",
    "    `labels` is an iterable of class ids; every item is passed through int(),\n",
    "    so plain ints and 0-d tensors both work. Returns a list of strings.\n",
    "    \"\"\"\n",
    "    names = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',\n",
    "             'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']\n",
    "    return list(map(lambda idx: names[int(idx)], labels))\n",
    "\n",
    "get_fashion_mnist_labels([1, 2, 2, 2, 3, 4, 5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_img(img: torch.Tensor, rows=28, cols=28):\n",
    "    \"\"\"Draw an image tensor as a grayscale picture after reshaping it to (rows, cols).\"\"\"\n",
    "    pixels = img.reshape(rows, cols)\n",
    "    plt.imshow(pixels, cmap=\"gray\")\n",
    "\n",
    "show_img(mnist_train[0][0]), get_fashion_mnist_labels([mnist_train[0][1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mini-batch loaders: reshuffle the training set every epoch, keep the test order fixed\n",
    "batch_size = 256\n",
    "num_workers = 0  # 0 = batches are loaded in the main process\n",
    "train_iter = data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)\n",
    "test_iter = data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time one full pass over train_iter to compare DataLoader worker settings\n",
    "import functools\n",
    "import time\n",
    "\n",
    "def timer(fun):\n",
    "    \"\"\"Decorator that prints how long each call to `fun` takes.\n",
    "\n",
    "    Uses time.perf_counter(), which is monotonic and meant for measuring\n",
    "    elapsed intervals (time.time() can jump when the system clock changes).\n",
    "    \"\"\"\n",
    "    @functools.wraps(fun)  # keep fun's __name__/__doc__ on the wrapper\n",
    "    def wrapper(*args, **kwargs):\n",
    "        start = time.perf_counter()\n",
    "        res = fun(*args, **kwargs)\n",
    "        end = time.perf_counter()\n",
    "        print(f\"[{fun.__name__}] run {end-start} seconds\")\n",
    "        return res\n",
    "    return wrapper\n",
    "\n",
    "@timer\n",
    "def speed_test():\n",
    "    \"\"\"Iterate over every training batch without doing any work.\"\"\"\n",
    "    for x, y in train_iter:\n",
    "        continue\n",
    "\n",
    "speed_test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: shapes of the first training batch\n",
    "x, y = next(iter(train_iter))\n",
    "print(x.shape, y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_n = 784   # 28 * 28 pixels per flattened image\n",
    "output_n = 10   # number of Fashion-MNIST classes\n",
    "torch.manual_seed(0)  # fixed seed so the initial weights (and the results) are reproducible\n",
    "# Parameters: W ~ N(0, 0.01^2), b = 0. b is a column vector (output_n, 1).\n",
    "W = torch.normal(0, 0.01, (input_n, output_n), requires_grad=True)\n",
    "b = torch.zeros((output_n, 1), requires_grad=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def softmax(X):\n",
    "    \"\"\"Row-wise softmax.\n",
    "\n",
    "    X holds one row of scores per sample, e.g. (batch_size, 10) when called from net();\n",
    "    the result has the same shape and every row is a probability distribution.\n",
    "    Subtracting each row's maximum before exp() leaves the result mathematically\n",
    "    unchanged but keeps exp() from overflowing to inf (which would give NaN).\n",
    "    \"\"\"\n",
    "    x_shifted = X - X.max(dim=1, keepdim=True).values\n",
    "    x_exp = torch.exp(x_shifted)\n",
    "    return x_exp / x_exp.sum(dim=1, keepdim=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def net(X: torch.Tensor):\n",
    "    \"\"\"Softmax-regression forward pass returning class probabilities.\n",
    "\n",
    "    X is a batch of images, e.g. (batch_size, 1, 28, 28); it is flattened to\n",
    "    (batch_size, 784) to match W and mapped to probabilities via softmax(X @ W + b^T).\n",
    "    \"\"\"\n",
    "    flat = X.reshape(-1, W.shape[0])\n",
    "    logits = flat @ W + b.T\n",
    "    return softmax(logits)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.0000, -1.3863])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def cross_entropy(y_hat, y):\n",
    "    \"\"\"Per-sample cross-entropy loss.\n",
    "\n",
    "    y_hat: (batch_size, num_classes) predicted probabilities; y: true class index per sample.\n",
    "    Because the target distribution is one-hot, the loss reduces to the negative log\n",
    "    of the probability predicted for the true class (one value per sample).\n",
    "    \"\"\"\n",
    "    rows = range(len(y_hat))\n",
    "    true_class_prob = y_hat[rows, y]\n",
    "    return -true_class_prob.log()\n",
    "\n",
    "# Demo: picks y_hat[0, 0] = 1 and y_hat[1, 1] = 4, giving -log(1) and -log(4)\n",
    "cross_entropy(torch.Tensor([[1, 2],\n",
    "                            [3, 4]]), [0, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.18359375\n"
     ]
    }
   ],
   "source": [
    "def accuracy(y_hat: torch.Tensor, y):\n",
    "    \"\"\"Fraction of samples whose highest-scoring class equals the label.\n",
    "\n",
    "    y_hat: (batch_size, num_classes) scores or probabilities; y: 1-D tensor of class ids.\n",
    "    Returns a Python float in [0, 1].\n",
    "    \"\"\"\n",
    "    predicted = y_hat.argmax(dim=1)\n",
    "    correct = predicted.type(y.dtype) == y\n",
    "    return float(correct.type(y.dtype).sum()) / len(y)\n",
    "\n",
    "# Accuracy on the first test batch before training.\n",
    "# no_grad: evaluation does not need an autograd graph.\n",
    "with torch.no_grad():\n",
    "    x, y = next(iter(test_iter))\n",
    "    print(accuracy(net(x), y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sgd(params, lr, batch_size):\n",
    "    \"\"\"Minibatch stochastic gradient descent, applied in place.\n",
    "\n",
    "    Each parameter moves by -lr * grad / batch_size (dividing by batch_size turns\n",
    "    the gradient of a summed batch loss into that of the mean loss), then its\n",
    "    gradient is reset to zero before the next backward pass.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():  # the update itself must not be recorded by autograd\n",
    "        for p in params:\n",
    "            p.sub_(lr * p.grad / batch_size)\n",
    "            p.grad.zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 10\n",
    "lr = 0.03\n",
    "\n",
    "def train():\n",
    "    \"\"\"Train W and b with minibatch SGD for `epochs` passes over train_iter.\n",
    "\n",
    "    The gradient is divided by the real size of each batch (the final batch of an\n",
    "    epoch can be smaller than batch_size), and the mean training loss per epoch is printed.\n",
    "    \"\"\"\n",
    "    for epoch in range(epochs):\n",
    "        loss_sum, n_samples = 0.0, 0\n",
    "        for x, y in train_iter:\n",
    "            y_hat = net(x)\n",
    "            l = cross_entropy(y_hat, y)  # 1-D tensor: one loss per sample\n",
    "            batch_loss = l.sum()\n",
    "            batch_loss.backward()\n",
    "            sgd([W, b], lr, len(y))  # len(y), not batch_size: the last batch may be smaller\n",
    "            loss_sum += batch_loss.item()\n",
    "            n_samples += len(y)\n",
    "        print(f\"epoch {epoch + 1}: train loss {loss_sum / max(n_samples, 1):.4f}\")\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.826171875"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def emulate(data_iter=None):\n",
    "    \"\"\"Classification accuracy of net over every sample in `data_iter` (default: test_iter).\n",
    "\n",
    "    Correct predictions are counted per sample, so a smaller final batch is weighted\n",
    "    by its real size instead of counting as much as a full batch.\n",
    "    \"\"\"\n",
    "    if data_iter is None:\n",
    "        data_iter = test_iter\n",
    "    n_correct, n_total = 0, 0\n",
    "    with torch.no_grad():  # evaluation only; no autograd graph needed\n",
    "        for x, y in data_iter:\n",
    "            n_correct += int((net(x).argmax(dim=1) == y).sum())\n",
    "            n_total += len(y)\n",
    "    return n_correct / n_total\n",
    "\n",
    "emulate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict:shirt\n",
      "actrully:shirt\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAhC0lEQVR4nO3dfWyV9f3/8dfp3WlL21PaQm+kYAsqKjfLUCtT+aI03Cwxosx5twSMgeiKGTKn6aKibkk3ljjjwvCfDWYm3iUC0SwkgrbMCRhQJMTRAasCoS03oz1t6R3t9fuDn92qgHw+np53W56P5EroOefV69OrV3lxcc55NxQEQSAAAOIswXoBAIBLEwUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAE0nWC/i63t5eHT16VJmZmQqFQtbLAQA4CoJALS0tKioqUkLC+a9zBl0BHT16VMXFxdbLAAB8R4cPH9aYMWPOe/+gK6DMzEzrJWAYyMvLi9u+Tpw4EZf9TJgwwTnT2trqta+GhgavHPC/vu3v8wF7DmjVqlW6/PLLlZqaqrKyMn388ccXleO/3RALCQkJcdviJTEx0Xkb7F8Thrdv+/t8QM60N954Q8uXL9eKFSv0ySefaOrUqZozZ46OHTs2ELsDAAxBA1JAL7zwghYvXqwHH3xQ11xzjV5++WWlp6frz3/+80DsDgAwBMW8gLq6urRr1y6Vl5f/dycJCSovL9e2bdu+8fjOzk5Fo9F+GwBg+It5AZ04cUI9PT3Kz8/vd3t+fv45n9isqqpSJBLp23gFHABcGsyfbaysrFRzc3PfdvjwYeslAQDiIOYvw87Ly1NiYqIaGxv73d7Y2KiCgoJvPD4cDiscDsd6GQCAQS7mV0ApKSmaNm2atmzZ0ndbb2+vtmzZounTp8d6dwCAIWpA3oi6fPlyLVy4UNddd51uuOEGvfjii2pra9ODDz44ELsDAAxBA1JA99xzj44fP65nnnlGDQ0N+t73vqdNmzZ944UJAIBLVygIgsB6Ef8rGo0qEolYL2PI8pkkMchOgZi4+eabvXI/+tGPnDM+46OSktz/7ZeRkeGc+etf/+qckaT169d75Vxxvg5vzc3NysrKOu/95q+CAwBcmiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJgYkGnYsDMcBzWOHz/eOXP33Xd77Wvfvn3OmdzcXOdMXl6ec+aTTz5xztx2223OGUk6deqUc6a6uto5MxzPV1w8roAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACZCwSAbRxuNRhWJRKyXcUlJS0vzypWVlTlnfvKTnzhnwuGwc+azzz5zzkjSyZMnnTNZWVnOmWg06pzJzs52zowdO9Y5I0mZmZnOmfb2dufMxo0bnTObN292zsBGc3PzBX8+uAICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABggmGkw8yzzz7rnBk5cqTXvnyGT/pkjh496pw5deqUc8Y3V1pa6pw5fvy4c8ZnsGhKSopzRpISExOdMxkZGXHJ9Pb2Omd6enqcM5L0xBNPeOVwFsNIAQCDEgUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABNJ1gsYihIS3HvbZ4Didddd55wZMWKEc2bfvn3OGUk6ceKEcyYtLc054zMs9cyZM84ZSerq6nLOdHZ2OmcKCwudM//+97+dM/n5+c4ZSRccIHk+R44ccc4cO3bMOfODH/zAOdPW1uackaRZs2Y5Z7Zs2eK1r0sRV0AAABMUEADARMwL6Nlnn1UoFOq3TZw4Mda7AQAMcQPyHNC1116rzZs3
/3cnSTzVBADob0CaISkpSQUFBQPxqQEAw8SAPAe0f/9+FRUVqbS0VA888IAOHTp03sd2dnYqGo322wAAw1/MC6isrExr167Vpk2btHr1atXV1emWW25RS0vLOR9fVVWlSCTStxUXF8d6SQCAQSjmBTRv3jzdfffdmjJliubMmaO//e1vampq0ptvvnnOx1dWVqq5ublvO3z4cKyXBAAYhAb81QHZ2dm68sordeDAgXPeHw6HFQ6HB3oZAIBBZsDfB9Ta2qqDBw96vfMbADB8xbyAHn/8cdXU1OiLL77QRx99pDvvvFOJiYm67777Yr0rAMAQFvP/gjty5Ijuu+8+nTx5UqNGjdLNN9+s7du3a9SoUbHeFQBgCIt5Ab3++uux/pSDjs9gUR+lpaXOmY8//tg5M2bMGOeMJHV3dztncnJynDNBEDhnOjo6nDOS3zDX5ORk50xqampc9pObm+uckaTExETnTENDg3PGZ+ipz/d2z549zhlJmjRpknOGYaQXj1lwAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATAz4L6SDv7S0NOdMNBqNS8ZXV1eXc+b06dMDsJJzy87Ods74DMc8ceKEc8ZnaKzv4Nzjx487Z1pbW50zkUjEOXO+X255Ib4DdzMyMrxyuDhcAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATDANO04mTJjgnElJSXHOzJ492zlz6NAh54wkHT161DnT3d3tnElIcP93ks9Ua0lKTU11zrS0tDhn2tranDNJSe4/runp6c4ZSUpOTnbO5OTkOGd8pk2PHDnSOeN7jpeWljpnJk6c6JzZt2+fc2Y44AoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACYaRxklJSYlz5syZM86Z3bt3O2duvPFG54wk/f3vf3fO5OXlOWd8BoS2trY6ZySpo6PDOdPT0+Oc8fmafIaynjhxwjkjSQUFBc6Zw4cPO2cSExOdM1deeaVz5rLLLnPOSNKXX37pnLn66qudMwwjBQAgjiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJhgGGmcTJgwwTnjMxgzGo06Z3zWJknhcNg54zNg9fjx484Zn2GaknT69GnnTFKS+49Re3u7c6a3t9c54zP0VJKampqcM/X19c6ZESNGOGeysrKcMz7nkCQ1NjY6Z3zPvUsRV0AAABMUEADAhHMBbd26VbfffruKiooUCoW0YcOGfvcHQaBnnnlGhYWFSktLU3l5ufbv3x+r9QIAhgnnAmpra9PUqVO1atWqc96/cuVKvfTSS3r55Ze1Y8cOjRgxQnPmzPF6PgMAMHw5P3s6b948zZs375z3BUGgF198UU899ZTuuOMOSdIrr7yi/Px8bdiwQffee+93Wy0AYNiI6XNAdXV1amhoUHl5ed9tkUhEZWVl2rZt2zkznZ2dikaj/TYAwPAX0wJqaGiQJOXn5/e7PT8/v+++r6uqqlIkEunbiouLY7kkAMAgZf4quMrKSjU3N/dthw8ftl4SACAOYlpAX70B6+tv3mpsbDzvm7PC4bCysrL6bQCA4S+mBVRSUqKCggJt2bKl77ZoNKodO3Zo+vTpsdwVAGCIc34VXGtrqw4cOND3cV1dnXbv3q2cnByNHTtWy5Yt069//WtdccUVKikp0dNPP62ioiLNnz8/lusGAAxxzgW0c+dO3XrrrX0fL1++XJK0cOFCrV27Vk888YTa2tq0ZMkSNTU16eabb9amTZu8Z1IBAIYn5wKaOXOmgiA47/2hUEjPP/+8nn/++e+0sOHm668MvBi7du1yzvgM0xwzZoxzRpISExOdMz5DIUeOHOmc
OXXqlHNG8jt+KSkpzhmfN2Y3Nzc7Z3zOO+ns2ydc+Rzzrq4u58y0adOcMz6DUiUpOTnZOZObm+u1r0uR+avgAACXJgoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACedp2Dj7W1xd+UxZzs7Ods74TLb+7LPPnDOS39fkIynJ/TS90MT2C/H5tSE+54PP1+STaWpqcs5IftO6z5w545wZMWKEc8bnHO/s7HTOSPL6Dc0+09HT09OdM/H6+RtIXAEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwwTBSD+PGjXPO9Pb2Omd8hjvm5uY6ZyKRiHNGkjIzM50zPT09zpmMjAznjM/x9s21trY6Z3yGYyYmJjpnfL5Hkt8QzuTkZOdMWlqac8bneLe0tDhnJCkajTpnLrvsMueMz4DVf/3rX86ZwYYrIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYYRuohJyfHOZOU5H6oi4uLnTM+g1J9hlxK0qlTp5wzPkMuffbjM/RUkrKzs50zoVDIOTNy5EjnjI/29nav3OnTp50z3d3dzpmGhgbnzN69e50z4XDYOSP5DUtNT093zpSWljpnGEYKAIAnCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJhhG6uHMmTPOmS+++MI54zNY1GeIZF1dnXNGklJTU50zPoNPu7q64rIfyW+Iqc8QTp+Mz/EOgsA5I/kNgM3NzXXOnDx50jnjM4QzPz/fOSNJEyZMcM40NjY6Zzo7O50zwwFXQAAAExQQAMCEcwFt3bpVt99+u4qKihQKhbRhw4Z+9y9atEihUKjfNnfu3FitFwAwTDgXUFtbm6ZOnapVq1ad9zFz585VfX193/baa699p0UCAIYf5xchzJs3T/PmzbvgY8LhsAoKCrwXBQAY/gbkOaDq6mqNHj1aV111lR555JELvtKls7NT0Wi03wYAGP5iXkBz587VK6+8oi1btui3v/2tampqNG/evPO+vLWqqkqRSKRvKy4ujvWSAACDUMzfB3Tvvff2/Xny5MmaMmWKxo8fr+rqas2aNesbj6+srNTy5cv7Po5Go5QQAFwCBvxl2KWlpcrLy9OBAwfOeX84HFZWVla/DQAw/A14AR05ckQnT55UYWHhQO8KADCEOP8XXGtra7+rmbq6Ou3evVs5OTnKycnRc889pwULFqigoEAHDx7UE088oQkTJmjOnDkxXTgAYGhzLqCdO3fq1ltv7fv4q+dvFi5cqNWrV2vPnj36y1/+oqamJhUVFWn27Nn61a9+pXA4HLtVAwCGvFDgO61wgESjUUUiEetlXFLGjh3rlfvxj3/snNm5c6dzJp7/feszaLajo8M5k5GR4ZxJSnJ/zVB7e7tzRpJCoZBzxuc4TJo0yTmzdetW50x9fb1zRpKSk5OdM7W1tV77Go6am5sv+Lw+s+AAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACZi/iu5MfQcP348bvvymS7sw2dytOS3Pp8J2qmpqc4Zn8nWCQl+/8ZMTEx0zvgc88zMTOfMRx995JwZZEP/8f9xBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEw0g9hEIh54zPMMR47aekpMQ5I/kNuvQZWNnR0eGc8RmmKfl9TT5DQn3W53M++A5l9fmafKSkpDhnRo8e7ZxpbGx0zkh+50Nvb6/Xvi5FXAEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEA
TFBAAAATFBAAwAQFBAAwwTBSDz4DPwfzfjo7O71ybW1tzpnu7m7nTHZ2tnMmNTXVOSNJGRkZzpkRI0Y4Z5KTk50zLS0tzhmf75Ekpaenx2VfPueD7/fWR7x+Bi9VXAEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwwTDSQSwUCjlnfIYnjhw50jkjSZdffrlzxmdwZ15ennMmNzfXOSNJvb29zploNOqc6enpiUumvb3dOSNJCQnx+bepzzk+atQo58yXX37pnJEYRjrQuAICAJiggAAAJpwKqKqqStdff70yMzM1evRozZ8/X7W1tf0e09HRoYqKCuXm5iojI0MLFixQY2NjTBcNABj6nAqopqZGFRUV2r59u9577z11d3dr9uzZ/X4R1WOPPaZ33nlHb731lmpqanT06FHdddddMV84AGBoc3oRwqZNm/p9vHbtWo0ePVq7du3SjBkz1NzcrD/96U9at26dbrvtNknSmjVrdPXVV2v79u268cYbY7dyAMCQ9p2eA2pubpYk5eTkSJJ27dql7u5ulZeX9z1m4sSJGjt2rLZt23bOz9HZ2aloNNpvAwAMf94F1Nvbq2XLlummm27SpEmTJEkNDQ1KSUlRdnZ2v8fm5+eroaHhnJ+nqqpKkUikbysuLvZdEgBgCPEuoIqKCu3du1evv/76d1pAZWWlmpub+7bDhw9/p88HABgavN6IunTpUr377rvaunWrxowZ03d7QUGBurq61NTU1O8qqLGxUQUFBef8XOFwWOFw2GcZAIAhzOkKKAgCLV26VOvXr9f777+vkpKSfvdPmzZNycnJ2rJlS99ttbW1OnTokKZPnx6bFQMAhgWnK6CKigqtW7dOGzduVGZmZt/zOpFIRGlpaYpEInrooYe0fPly5eTkKCsrS48++qimT5/OK+AAAP04FdDq1aslSTNnzux3+5o1a7Ro0SJJ0u9//3slJCRowYIF6uzs1Jw5c/THP/4xJosFAAwfTgV0MYP5UlNTtWrVKq1atcp7UTgrXoMQfYY7StLx48fjkmlpaXHOtLa2Omckqbu72znzv2/EHsj9+PAZrir5DSM9c+aMc+aaa65xzmRmZjpnMDgxCw4AYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYMLrN6JieElOTvbK1dfXx3gl59bT0+Oc8ZmgLflNgW5qanLOpKWlOWfa29vjkpGkjIwM54zPVPA9e/Y4Z87325UHi1Ao5JyJ1+T7wYYrIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYYRgrl5+d75bq6upwzqampzpn09HTnTGJionPGN+fzNSUluf/o+azN59hJfl9TvPaTk5MzACuBBa6AAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmGAYaZyEQiHnTBAEzhmfgZVpaWnOGUnq7Ox0zsRryGU8+Rxz32Gp8ZKQ4P5v097eXueMz0Db3Nxc5wwGJ66AAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmGAYaZzEaxhpVlaWc6awsNA5I0n/+c9/nDM+AyvjlZGkpCT3HwmfwZ0+w0h99nPmzBnnjOR3/OI1wPSKK65wzuTk5DhnJL9z3Oc49PT0OGeGA66AAAAmKCAAgAmnAqqqqtL111+vzMxMjR49WvPnz1dtbW2/x8ycOVOhUKjf9vDDD8d00QCAoc+pgGpqalRRUaHt27frvffeU3d3t2bPnq22trZ+j1u8eLHq6+v7tpUrV8Z00QCAoc/pGddNmzb1+3jt2rUaPXq0du3apRkzZvTdnp6eroKC
gtisEAAwLH2n54Cam5slffMVJq+++qry8vI0adIkVVZW6vTp0+f9HJ2dnYpGo/02AMDw5/0y7N7eXi1btkw33XSTJk2a1Hf7/fffr3HjxqmoqEh79uzRk08+qdraWr399tvn/DxVVVV67rnnfJcBABiivAuooqJCe/fu1Ycfftjv9iVLlvT9efLkySosLNSsWbN08OBBjR8//hufp7KyUsuXL+/7OBqNqri42HdZAIAhwquAli5dqnfffVdbt27VmDFjLvjYsrIySdKBAwfOWUDhcFjhcNhnGQCAIcypgIIg0KOPPqr169erurpaJSUl35rZvXu3JP932wMAhienAqqoqNC6deu0ceNGZWZmqqGhQZIUiUSUlpamgwcPat26dfrhD3+o3Nxc7dmzR4899phmzJihKVOmDMgXAAAYmpwKaPXq1ZLOvtn0f61Zs0aLFi1SSkqKNm/erBdffFFtbW0qLi7WggUL9NRTT8VswQCA4cH5v+AupLi4WDU1Nd9pQQCASwPTsOPEZ7K1j1OnTjlnduzY4bWvG2+80TmTmZnpnPF5b1h3d7dzRvKbzuzzvfWZjj5ixAjnTEpKinNGkkaNGuWc6erqcs6kpqY6Z756XtmFz1RrX76T2C9FDCMFAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgIhTEa0rmRYpGo4pEItbLwEWYOHGic+bbfoPuufgMMM3KynLOSFJ2drZzJinJfaavz7DUlpaWuOxHks6cOeOc+er3g7n4/PPPnTPHjh1zzsBGc3PzBX8WuQICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAn3IVYDbJCNpsMF9PT0OGd8Zoz5zDPr6upyzkhSZ2enc8bnOMTra/KdBRev721vb69zBkPHt/19PugKyGfgImzs378/LhkAQ1NLS8sFh0sPumnYvb29Onr0qDIzMxUKhfrdF41GVVxcrMOHD3tPOx4OOA5ncRzO4jicxXE4azAchyAI1NLSoqKiIiUknP+ZnkF3BZSQkPCtI/uzsrIu6RPsKxyHszgOZ3EczuI4nGV9HC7m1+rwIgQAgAkKCABgYkgVUDgc1ooVKxQOh62XYorjcBbH4SyOw1kch7OG0nEYdC9CAABcGobUFRAAYPiggAAAJiggAIAJCggAYGLIFNCqVat0+eWXKzU1VWVlZfr444+tlxR3zz77rEKhUL9t4sSJ1ssacFu3btXtt9+uoqIihUIhbdiwod/9QRDomWeeUWFhodLS0lReXj4sR/5823FYtGjRN86PuXPn2ix2gFRVVen6669XZmamRo8erfnz56u2trbfYzo6OlRRUaHc3FxlZGRowYIFamxsNFrxwLiY4zBz5sxvnA8PP/yw0YrPbUgU0BtvvKHly5drxYoV+uSTTzR16lTNmTNHx44ds15a3F177bWqr6/v2z788EPrJQ24trY2TZ06VatWrTrn/StXrtRLL72kl19+WTt27NCIESM0Z84cdXR0xHmlA+vbjoMkzZ07t9/58dprr8VxhQOvpqZGFRUV2r59u9577z11d3dr9uzZamtr63vMY489pnfeeUdvvfWWampqdPToUd11112Gq469izkOkrR48eJ+58PKlSuNVnwewRBwww03BBUVFX0f9/T0BEVFRUFVVZXhquJvxYoVwdSpU62XYUpSsH79+r6Pe3t7g4KCguB3v/td321NTU1BOBwOXnvtNYMVxsfXj0MQBMHChQuDO+64w2Q9Vo4dOxZICmpqaoIgOPu9T05ODt56662+x/zzn/8MJAXbtm2zWuaA+/pxCIIg+L//+7/gZz/7md2iLsKgvwLq6urSrl27VF5e3ndbQkKCysvLtW3bNsOV2di/f7+KiopUWlqqBx54QIcOHbJekqm6ujo1NDT0Oz8ikYjKysouyfOjurpao0eP
1lVXXaVHHnlEJ0+etF7SgGpubpYk5eTkSJJ27dql7u7ufufDxIkTNXbs2GF9Pnz9OHzl1VdfVV5eniZNmqTKykqdPn3aYnnnNeiGkX7diRMn1NPTo/z8/H635+fna9++fUarslFWVqa1a9fqqquuUn19vZ577jndcsst2rt3rzIzM62XZ6KhoUGSznl+fHXfpWLu3Lm66667VFJSooMHD+qXv/yl5s2bp23btikxMdF6eTHX29urZcuW6aabbtKkSZMknT0fUlJSlJ2d3e+xw/l8ONdxkKT7779f48aNU1FRkfbs2aMnn3xStbW1evvttw1X29+gLyD817x58/r+PGXKFJWVlWncuHF688039dBDDxmuDIPBvffe2/fnyZMna8qUKRo/fryqq6s1a9Ysw5UNjIqKCu3du/eSeB70Qs53HJYsWdL358mTJ6uwsFCzZs3SwYMHNX78+Hgv85wG/X/B5eXlKTEx8RuvYmlsbFRBQYHRqgaH7OxsXXnllTpw4ID1Usx8dQ5wfnxTaWmp8vLyhuX5sXTpUr377rv64IMP+v36loKCAnV1dampqanf44fr+XC+43AuZWVlkjSozodBX0ApKSmaNm2atmzZ0ndbb2+vtmzZounTpxuuzF5ra6sOHjyowsJC66WYKSkpUUFBQb/zIxqNaseOHZf8+XHkyBGdPHlyWJ0fQRBo6dKlWr9+vd5//32VlJT0u3/atGlKTk7udz7U1tbq0KFDw+p8+LbjcC67d++WpMF1Pli/CuJivP7660E4HA7Wrl0bfP7558GSJUuC7OzsoKGhwXppcfXzn/88qK6uDurq6oJ//OMfQXl5eZCXlxccO3bMemkDqqWlJfj000+DTz/9NJAUvPDCC8Gnn34afPnll0EQBMFvfvObIDs7O9i4cWOwZ8+e4I477ghKSkqC9vZ245XH1oWOQ0tLS/D4448H27ZtC+rq6oLNmzcH3//+94Mrrrgi6OjosF56zDzyyCNBJBIJqqurg/r6+r7t9OnTfY95+OGHg7Fjxwbvv/9+sHPnzmD69OnB9OnTDVcde992HA4cOBA8//zzwc6dO4O6urpg48aNQWlpaTBjxgzjlfc3JAooCILgD3/4QzB27NggJSUluOGGG4Lt27dbLynu7rnnnqCwsDBISUkJLrvssuCee+4JDhw4YL2sAffBBx8Ekr6xLVy4MAiCsy/Ffvrpp4P8/PwgHA4Hs2bNCmpra20XPQAudBxOnz4dzJ49Oxg1alSQnJwcjBs3Lli8ePGw+0faub5+ScGaNWv6HtPe3h789Kc/DUaOHBmkp6cHd955Z1BfX2+36AHwbcfh0KFDwYwZM4KcnJwgHA4HEyZMCH7xi18Ezc3Ntgv/Gn4dAwDAxKB/DggAMDxRQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAw8f8AEsKf5Xca1vIAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "def emulate_vision(index=None):\n",
    "    \"\"\"Show one test image with the model's predicted label and the true label.\n",
    "\n",
    "    index: position in mnist_test to display; a random sample is chosen when None.\n",
    "    \"\"\"\n",
    "    if index is None:\n",
    "        index = random.randrange(len(mnist_test))\n",
    "    img, y = mnist_test[index]  # load the sample once\n",
    "    with torch.no_grad():  # inference only\n",
    "        y_hat = net(img).argmax()\n",
    "    show_img(img)\n",
    "    print(f\"predict:{get_fashion_mnist_labels([y_hat])[0]}\\nactually:{get_fashion_mnist_labels([y])[0]}\")\n",
    "\n",
    "emulate_vision()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
